{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "TripletMarginLossMNIST.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "lMde13VLDjiH",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 225
        },
        "outputId": "7e2d385b-f008-4b49-9ad1-466c9ff4d3b9"
      },
      "source": [
        "# Pinned to the versions shown in this notebook's recorded install output;\n",
        "# the code below relies on the pytorch-metric-learning API of that release.\n",
        "%pip install pytorch-metric-learning==0.9.90\n",
        "%pip install faiss-gpu==1.6.3"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: pytorch-metric-learning in /usr/local/lib/python3.6/dist-packages (0.9.90)\n",
            "Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from pytorch-metric-learning) (1.6.0+cu101)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from pytorch-metric-learning) (1.18.5)\n",
            "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from pytorch-metric-learning) (0.22.2.post1)\n",
            "Requirement already satisfied: torchvision in /usr/local/lib/python3.6/dist-packages (from pytorch-metric-learning) (0.7.0+cu101)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from pytorch-metric-learning) (4.41.1)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch>=1.6.0->pytorch-metric-learning) (0.16.0)\n",
            "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->pytorch-metric-learning) (0.16.0)\n",
            "Requirement already satisfied: scipy>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->pytorch-metric-learning) (1.4.1)\n",
            "Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision->pytorch-metric-learning) (7.0.0)\n",
            "Requirement already satisfied: faiss-gpu in /usr/local/lib/python3.6/dist-packages (1.6.3)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from faiss-gpu) (1.18.5)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GJ_L0TrTDnEA",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 277
        },
        "outputId": "cd926fd9-b75f-4bd8-e531-4d13d4fcf01d"
      },
      "source": [
        "from pytorch_metric_learning import losses, miners, distances, reducers, testers\n",
        "from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator\n",
        "### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### \n",
        "from torchvision import datasets\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "from torchvision import datasets, transforms\n",
        "import numpy as np\n",
        "\n",
        "### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### \n",
        "class Net(nn.Module):\n",
        "    \"\"\"Small convolutional network that maps an image batch to 128-dimensional embeddings.\n",
        "\n",
        "    The linear layer expects 9216 flattened features. For 1x28x28 inputs that is\n",
        "    64 channels x 12 x 12 after the two 3x3 convolutions and the 2x2 max-pool.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)\n",
        "        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)\n",
        "        self.dropout1 = nn.Dropout2d(p=0.25)\n",
        "        # Registered on the module but never applied in forward().\n",
        "        self.dropout2 = nn.Dropout2d(p=0.5)\n",
        "        self.fc1 = nn.Linear(in_features=9216, out_features=128)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return embeddings: conv, relu, conv, relu, max-pool, dropout, flatten, linear.\"\"\"\n",
        "        features = F.relu(self.conv1(x))\n",
        "        features = F.relu(self.conv2(features))\n",
        "        pooled = self.dropout1(F.max_pool2d(features, 2))\n",
        "        return self.fc1(torch.flatten(pooled, 1))\n",
        "\n",
        "### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### \n",
        "def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch, log_interval=20):\n",
        "    \"\"\"Run one training epoch with mined triplets.\n",
        "\n",
        "    Args:\n",
        "        model: network mapping a data batch to embeddings.\n",
        "        loss_func: callable invoked as loss_func(embeddings, labels, indices_tuple).\n",
        "        mining_func: callable invoked as mining_func(embeddings, labels); its\n",
        "            num_triplets attribute is read for the progress message.\n",
        "        device: device each batch is moved to before the forward pass.\n",
        "        train_loader: iterable yielding (data, labels) batches.\n",
        "        optimizer: optimizer that updates the model parameters.\n",
        "        epoch: epoch number, used only in the progress message.\n",
        "        log_interval: print progress every this many batches; 0 or None\n",
        "            disables the progress output.\n",
        "    \"\"\"\n",
        "    model.train()\n",
        "    for batch_idx, (data, labels) in enumerate(train_loader):\n",
        "        data, labels = data.to(device), labels.to(device)\n",
        "        optimizer.zero_grad()\n",
        "        embeddings = model(data)\n",
        "        indices_tuple = mining_func(embeddings, labels)\n",
        "        loss = loss_func(embeddings, labels, indices_tuple)\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        # A falsy interval skips logging and avoids a modulo-by-zero.\n",
        "        if log_interval and batch_idx % log_interval == 0:\n",
        "            print(\n",
        "                \"Epoch {} Iteration {}: Loss = {}, Number of mined triplets = {}\".format(\n",
        "                    epoch, batch_idx, loss.item(), mining_func.num_triplets\n",
        "                )\n",
        "            )\n",
        "\n",
        "### convenient function from pytorch-metric-learning ###\n",
        "def get_all_embeddings(dataset, model):\n",
        "    \"\"\"Embed every sample of dataset with model using a default testers.BaseTester.\n",
        "\n",
        "    Returns whatever BaseTester.get_all_embeddings returns; the caller in this\n",
        "    notebook unpacks it as (embeddings, labels).\n",
        "    \"\"\"\n",
        "    return testers.BaseTester().get_all_embeddings(dataset, model)\n",
        "\n",
        "### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###\n",
        "def test(dataset, model, accuracy_calculator):\n",
        "    \"\"\"Embed dataset with model and print the mean_average_precision_at_r accuracy.\"\"\"\n",
        "    embeddings, labels = get_all_embeddings(dataset, model)\n",
        "    print(\"Computing accuracy\")\n",
        "    flat_labels = np.squeeze(labels)\n",
        "    # The same embeddings and labels are passed twice; the trailing True presumably\n",
        "    # marks them as coming from the same source. Confirm against the installed\n",
        "    # AccuracyCalculator.get_accuracy signature.\n",
        "    accuracies = accuracy_calculator.get_accuracy(\n",
        "        embeddings, embeddings, flat_labels, flat_labels, True\n",
        "    )\n",
        "    metric_name = \"mean_average_precision_at_r\"\n",
        "    print(\"Test set accuracy (MAP@10) = {}\".format(accuracies[metric_name]))\n",
        "\n",
        "# Seed PyTorch so weight initialisation and shuffling order repeat across runs.\n",
        "torch.manual_seed(0)\n",
        "\n",
        "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "transform = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize((0.1307,), (0.3081,)),\n",
        "])\n",
        "\n",
        "batch_size = 256\n",
        "\n",
        "dataset1 = datasets.MNIST('.', train=True, download=True, transform=transform)\n",
        "dataset2 = datasets.MNIST('.', train=False, transform=transform)\n",
        "# Both loaders take their size from batch_size so it only has to be changed in one place.\n",
        "train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batch_size, shuffle=True)\n",
        "test_loader = torch.utils.data.DataLoader(dataset2, batch_size=batch_size)\n",
        "\n",
        "model = Net().to(device)\n",
        "optimizer = optim.Adam(model.parameters(), lr=0.01)\n",
        "num_epochs = 1\n",
        "\n",
        "\n",
        "### pytorch-metric-learning stuff ###\n",
        "# The loss and the miner share the same distance object and the same margin value.\n",
        "distance = distances.CosineSimilarity()\n",
        "reducer = reducers.ThresholdReducer(low=0)\n",
        "loss_func = losses.TripletMarginLoss(margin=0.2, distance=distance, reducer=reducer)\n",
        "mining_func = miners.TripletMarginMiner(margin=0.2, distance=distance, type_of_triplets=\"semihard\")\n",
        "accuracy_calculator = AccuracyCalculator(include=(\"mean_average_precision_at_r\",), k=10)\n",
        "### pytorch-metric-learning stuff ###\n",
        "\n",
        "\n",
        "# Train for num_epochs epochs, evaluating on the held-out split after each one.\n",
        "for epoch in range(1, num_epochs + 1):\n",
        "    train(model, loss_func, mining_func, device, train_loader, optimizer, epoch)\n",
        "    test(dataset2, model, accuracy_calculator)\n",
        "\n"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 1 Iteration 0: Loss = 0.10368159413337708, Number of mined triplets = 828800\n",
            "Epoch 1 Iteration 20: Loss = 0.09451039880514145, Number of mined triplets = 117172\n",
            "Epoch 1 Iteration 40: Loss = 0.0870940089225769, Number of mined triplets = 75505\n",
            "Epoch 1 Iteration 60: Loss = 0.08687727153301239, Number of mined triplets = 63229\n",
            "Epoch 1 Iteration 80: Loss = 0.08603870868682861, Number of mined triplets = 39559\n",
            "Epoch 1 Iteration 100: Loss = 0.0871090441942215, Number of mined triplets = 49315\n",
            "Epoch 1 Iteration 120: Loss = 0.08383794873952866, Number of mined triplets = 38464\n",
            "Epoch 1 Iteration 140: Loss = 0.08661019802093506, Number of mined triplets = 28360\n",
            "Epoch 1 Iteration 160: Loss = 0.07955317199230194, Number of mined triplets = 17407\n",
            "Epoch 1 Iteration 180: Loss = 0.0819648876786232, Number of mined triplets = 26285\n",
            "Epoch 1 Iteration 200: Loss = 0.08189083635807037, Number of mined triplets = 22461\n",
            "Epoch 1 Iteration 220: Loss = 0.0827690064907074, Number of mined triplets = 23761\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "100%|██████████| 313/313 [00:03<00:00, 88.29it/s]"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Computing accuracy\n",
            "Test set accuracy (MAP@10) = 0.9738574404761905\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "\n"
          ],
          "name": "stderr"
        }
      ]
    }
  ]
}